package cn.lingyangwl.framework.core.utils;

import cn.lingyangwl.framework.tool.core.exception.Assert;
import cn.lingyangwl.framework.tool.core.exception.AsyncTaskException;
import org.apache.commons.lang3.time.StopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.stereotype.Component;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Function;

/**
 * Asynchronous task executor with transaction support. Features:
 * 1. Multiple tasks can be added and run concurrently; the calling thread blocks until every task has finished.
 * 2. Several threads can work on the database at the same time, e.g. to save or update data.
 * 3. Transactions are supported: each task runs in its own manually managed transaction, and the transactions are
 *    committed only if every task succeeded; otherwise they are all rolled back.
 *
 * <p>Obtain a working instance with {@link #init(Executor)}. The Spring-managed bean created through the
 * {@code @Autowired} constructor only captures the {@link TransactionDefinition} and
 * {@link DataSourceTransactionManager} into static fields that every instance shares.
 *
 * <p>Caveats:
 * <ul>
 *   <li>Each task keeps its executor thread (and its open transaction) until <em>all</em> tasks have finished, so
 *       the executor must be able to run every added task at the same time. With fewer threads, or a policy that
 *       runs rejected tasks on the caller thread, {@link #execute()} blocks forever.</li>
 *   <li>The final commits are performed independently by each thread and are therefore not atomic: if one commit
 *       fails after another one succeeded, data is partially committed. Such a failure makes {@link #execute()}
 *       throw.</li>
 *   <li>An instance is meant to be executed once: the task list and the failure flag are never reset.</li>
 * </ul>
 *
 * @author shenguangyang
 * @apiNote Task methods do not need a transaction annotation.
 */
@Component
@ConditionalOnClass({TransactionDefinition.class, DataSourceTransactionManager.class})
public class AsyncTaskExecutorWithTx {
    private static final Logger log = LoggerFactory.getLogger(AsyncTaskExecutorWithTx.class);

    /**
     * Definition used to open every task's transaction; set by the Spring-managed bean.
     */
    private static TransactionDefinition transactionDefinition;

    /**
     * Manager used to open, commit and roll back every task's transaction; set by the Spring-managed bean.
     */
    private static DataSourceTransactionManager dataSourceTransactionManager;

    /**
     * Tasks to run, each one inside its own transaction.
     */
    private final List<Task> taskList = new CopyOnWriteArrayList<>();

    /**
     * Becomes {@code true} as soon as any task (or its transaction handling) fails; every transaction that has not
     * been completed yet is then rolled back.
     */
    private final AtomicBoolean taskExeFail = new AtomicBoolean(false);

    /**
     * Executor that runs the tasks; set by {@link #init(Executor)}.
     */
    private Executor executor;

    @Autowired
    public AsyncTaskExecutorWithTx(TransactionDefinition transactionDefinition, DataSourceTransactionManager dataSourceTransactionManager) {
        AsyncTaskExecutorWithTx.transactionDefinition = transactionDefinition;
        AsyncTaskExecutorWithTx.dataSourceTransactionManager = dataSourceTransactionManager;
    }

    private AsyncTaskExecutorWithTx() {
    }

    /**
     * Creates a new executor instance whose tasks run on the given executor.
     *
     * @param executor executor that runs the tasks; must not be null
     * @return a new, empty instance
     */
    public static AsyncTaskExecutorWithTx init(Executor executor) {
        Assert.notNull(executor, "executor is null");
        AsyncTaskExecutorWithTx asyncTaskExecutor = new AsyncTaskExecutorWithTx();
        asyncTaskExecutor.executor = executor;
        return asyncTaskExecutor;
    }

    /**
     * Adds a task made of a processing step followed by a post-processing step. Both steps run inside the task's
     * transaction.
     *
     * @param preProcessorIn     input data for the pre-processor
     * @param preWithTxProcessor receives the input data, processes it and returns a result (runs in the transaction)
     * @param postProcessor      further processes the result of the pre-processor (also runs in the transaction)
     * @param <T>                input type
     * @param <R>                output type of the pre-processor
     * @return this
     */
    public <T, R> AsyncTaskExecutorWithTx addTask(T preProcessorIn, Function<T, R> preWithTxProcessor, Consumer<R> postProcessor) {
        Assert.notNull(preWithTxProcessor, "preWithTxProcessor is null");
        Assert.notNull(postProcessor, "postProcessor is null");
        taskList.add(() -> {
            R postProcessIn = preWithTxProcessor.apply(preProcessorIn);
            postProcessor.accept(postProcessIn);
        });
        return this;
    }

    /**
     * Adds a task that consumes the given input inside the task's transaction.
     *
     * @param processorIn     input data for the processor
     * @param withTxProcessor processes the input data (runs in the transaction)
     * @param <T>             input type
     * @return this
     */
    public <T> AsyncTaskExecutorWithTx addTask(T processorIn, Consumer<T> withTxProcessor) {
        Assert.notNull(withTxProcessor, "withTxProcessor is null");
        taskList.add(() -> withTxProcessor.accept(processorIn));
        return this;
    }


    /**
     * Runs all added tasks concurrently and waits until every task's transaction has been committed or rolled back.
     *
     * <p>Each task runs in its own transaction on an executor thread. Once every task has finished, all transactions
     * are committed if every task succeeded, otherwise they are all rolled back.
     *
     * @throws AsyncTaskException if a task, a transaction start, a task submission or a commit failed, or if the
     *                            calling thread was interrupted while waiting (its interrupt flag is restored and
     *                            transactions not completed yet are rolled back)
     */
    public void execute() {
        // Make sure the transaction manager has been injected
        Assert.notNull(transactionDefinition, "transactionDefinition is null");
        Assert.notNull(dataSourceTransactionManager, "dataSourceTransactionManager is null");

        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        // Snapshot the tasks so the latch counts always match the tasks actually scheduled
        Task[] tasks = taskList.toArray(new Task[0]);
        // Released once every task has finished running (successfully or not)
        CountDownLatch taskExeEndLatch = new CountDownLatch(tasks.length);
        // Released once every task's transaction has been committed or rolled back
        CountDownLatch taskTxLatch = new CountDownLatch(tasks.length);
        submitTasks(tasks, taskExeEndLatch, taskTxLatch);
        try {
            taskTxLatch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            // We stop waiting: make every transaction that is not completed yet roll back instead of commit
            taskExeFail.set(true);
        }
        stopWatch.stop();
        long time = stopWatch.getTime(TimeUnit.MILLISECONDS);
        log.debug("async task total time: {} ms", time);
        if (taskExeFail.get()) {
            throw new AsyncTaskException("async task exe fail");
        }
    }

    /**
     * Submits every task to the executor. If the executor refuses a task, the run is marked as failed and both
     * latches are released on behalf of that task and of every task that is not submitted, so nobody waits forever.
     */
    private void submitTasks(Task[] tasks, CountDownLatch taskExeEndLatch, CountDownLatch taskTxLatch) {
        for (int i = 0; i < tasks.length; i++) {
            Task task = tasks[i];
            try {
                executor.execute(() -> runWithTx(task, taskExeEndLatch, taskTxLatch));
            } catch (RuntimeException e) {
                log.error("submit task fail: {}", e.getMessage(), e);
                taskExeFail.set(true);
                for (int j = i; j < tasks.length; j++) {
                    taskExeEndLatch.countDown();
                    taskTxLatch.countDown();
                }
                return;
            }
        }
    }

    /**
     * Body executed on an executor thread: runs one task inside its own transaction, then completes that
     * transaction once every task has finished. Each latch is counted down exactly once per task.
     */
    private void runWithTx(Task task, CountDownLatch taskExeEndLatch, CountDownLatch taskTxLatch) {
        TransactionStatus transactionStatus = null;
        try {
            // Manually open this task's transaction; a failure here is a task failure, not a hang
            transactionStatus = dataSourceTransactionManager.getTransaction(transactionDefinition);
            task.runTask();
        } catch (Throwable e) {
            // Errors are caught too: a task that died must never lead to a commit
            log.error("error: {}", e.getMessage(), e);
            taskExeFail.set(true);
        } finally {
            taskExeEndLatch.countDown();
        }
        try {
            if (transactionStatus != null) {
                completeTx(transactionStatus, taskExeEndLatch);
            }
        } finally {
            taskTxLatch.countDown();
        }
    }

    /**
     * Waits until every task has finished running, then commits the given transaction if all tasks succeeded or
     * rolls it back otherwise. A failure here marks the whole run as failed so that {@link #execute()} reports it.
     */
    private void completeTx(TransactionStatus transactionStatus, CountDownLatch taskExeEndLatch) {
        try {
            taskExeEndLatch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            // The outcome of the other tasks is unknown, so do not risk committing
            taskExeFail.set(true);
        }
        try {
            if (taskExeFail.get()) {
                // Manually roll back the transaction
                log.debug("rollback task thread {}", Thread.currentThread().getName());
                dataSourceTransactionManager.rollback(transactionStatus);
            } else {
                // Manually commit the transaction
                dataSourceTransactionManager.commit(transactionStatus);
            }
        } catch (RuntimeException e) {
            log.error("error: {}", e.getMessage(), e);
            // A failed commit must not be reported as success; transactions not completed yet will roll back
            taskExeFail.set(true);
        }
    }

    /**
     * A unit of work executed inside its own transaction.
     */
    @FunctionalInterface
    public interface Task {
        void runTask();
    }
}
